# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) OpenMMLab. All rights reserved.
import gc

import numpy as np

gc.disable()
import warnings

# suppress the torch.cpu.amp.autocast warning entirely
warnings.filterwarnings("ignore", message=".*torch.cpu.amp.autocast.*")
from torch.nn import functional as F
import argparse
import math
import os
import time
from datetime import timedelta, datetime
from functools import partial
import random
import re
import torch
import sys
from collections import OrderedDict
import shutil
import torch.distributed as dist
from mmengine.runner import set_random_seed
from mmengine.utils import mkdir_or_exist, get_git_hash
from mmengine.utils.dl_utils import collect_env
from PIL import Image
import json
from torch.optim import AdamW
from torch.distributed.checkpoint.state_dict import (StateDictOptions,
                                                     get_model_state_dict,
                                                     get_state_dict, set_state_dict)
from torch.optim.lr_scheduler import CosineAnnealingLR, LambdaLR
from xtuner._lite.parallel import (MetaStateful, get_dp_mesh, get_fsdp_mesh,
                                   get_sp_mesh, get_tp_mesh, get_world_mesh,
                                   setup_parallel)
from torch.distributed._composable.fsdp import MixedPrecisionPolicy

from xtuner._lite.parallel.megatron import megatron_parallelize
from xtuner._lite.accelerate import (LoadWoInit, dispatch_hf_code, packed_sequence,
                                     profile_time_and_memory)
from xtuner._lite.parallel.fsdp import clip_grad_norm_
import torch.distributed.checkpoint as dcp
import copy
from xtuner._lite.datasets.internvl2.dataset import build_train_dataloader, BaseOrigDataset
from xtuner._lite.datasets.utils import move_data_to_device, apply_exif_orientation
from xtuner._lite import get_repo_git_info
from janus.models import MultiModalityCausalLM, VLChatProcessor
from rigx.datasets import build_dataset
from torch import multiprocessing as mp
from torch.utils.tensorboard import SummaryWriter
from concurrent.futures import wait
from xtuner._lite import get_logger, get_torch_device_module, get_device
from torchvision.utils import make_grid

logger = get_logger()
DEVICE = get_device()
DEVICE_MODULE = get_torch_device_module()


def parse_args():
    """Build the command-line parser of the training script and parse ``sys.argv``.

    Options are registered in three titled groups (model, data, optim) plus a
    set of ungrouped runtime options (work dir, checkpointing, logging,
    resuming, seed).  Option spellings such as ``--mirco-batch-size`` are kept
    exactly as they are because other code reads the resulting attribute
    names (e.g. ``args.mirco_batch_size``).

    Returns:
        argparse.Namespace: the parsed options.
    """
    parser = argparse.ArgumentParser(description='Train LLM')

    model_args = parser.add_argument_group('model', 'Model Related Settings')
    model_args.add_argument(
        '--model', help='repo id or local path of the model')
    parser.add_argument(
        '--liger', action='store_true', help='use liger kernel')
    parser.add_argument('--use-hsdp', action='store_true')
    parser.add_argument('--gradient-sync-after-accumulate', action='store_true')
    parser.add_argument('--tensorboard', action='store_true')
    model_args.add_argument(
        '--freeze-style',
        choices=['mode1', 'mode2', 'mode3'],
        help="Not updating parameters")
    model_args.add_argument(
        '--dtype',
        default='auto',
        choices=['fp16', 'bf16', 'auto'],
        help=("the dtype of the model forward. When set to 'auto', it will "
              'automatically determine whether bf16 is available, '
              'prioritizing the use of bf16.'))
    model_args.add_argument(
        '--selective-recompute',
        default=1.0,
        type=float,
        help=('the ratio of re-computation for transformer layers. '
              'The maximum is 1; the larger the value, the less memory '
              'required for training. The default is 1, meaning all layers '
              'need to be re-computed.'))
    parser.add_argument(
        '--reshard-after-forward', action='store_true', help='')
    model_args.add_argument('--sp-size', type=int, default=1, help='')
    model_args.add_argument(
        '--tp-size',
        default=1,
        type=int,
        help="tp size")

    data_args = parser.add_argument_group('data', 'Dataset Related Settings')
    data_args.add_argument(
        '--datasets',
        help=('repo id or local path or dir of the datasets. For repo ids, '
              'the `dset-sources` needs to be appropriately set to '
              '`modelscope` or `huggingface`. For local dir, all json and '
              'jsonl files will be loaded by default. The type of loaded '
              'files can be controlled by setting `dset-file-type`'))
    data_args.add_argument(
        '--dset-cache-dir',
        help=('the cache dir of the loaded datasets. When the `datasets` is '
              'set, the loaded datasets will be cached to this dir. If the '
              '`datasets` are not set, the cached dataset in this dir will be '
              'loaded.'))
    data_args.add_argument('--dset-pack', action='store_true')
    data_args.add_argument('--concat-before-pack', action='store_true')
    data_args.add_argument('--group-by-length', action='store_true')
    data_args.add_argument('--group-by-modality-length', action='store_true')
    data_args.add_argument(
        '--max-length',
        type=int,
        default=4096,
        help=('the maximum length of each piece of data, any excess will be '
              'truncated.'))
    data_args.add_argument(
        '--pack-max-length',
        type=int,
        default=8192,
        help='the maximum length of each pack of data')
    data_args.add_argument(
        '--max-keep-ckpts',
        type=int,
        default=1,
        help='the maximum number of checkpoints to keep.')
    data_args.add_argument(
        '--num-workers',
        type=int,
        default=1,
        help='how many subprocesses to use for data loading.')
    data_args.add_argument(
        '--pack-len-type',
        default='total_block',
        choices=['total_block', 'max_block'],
        help='')
    data_args.add_argument(
        '--pack-extra-buffer-size',
        type=int,
        default=1000,
        help='')

    optim_args = parser.add_argument_group('optim', 'Optim Related Settings')
    optim_args.add_argument(
        '--mirco-batch-size',
        type=int,
        default=1,
        help='batch size for each forward + backward pass')
    optim_args.add_argument(
        '--global-batch-size',
        type=int,
        default=16,
        help='batch size for each parameter update')

    optim_args.add_argument(
        '--lr', default=4e-5, type=float, help='learning rate.')
    optim_args.add_argument(
        '--lr-min', default=0, type=float, help='min learning rate.')
    optim_args.add_argument(
        '--wd', default=0.01, type=float, help='weight decay.')
    optim_args.add_argument(
        '--max-grad-norm', default=1, type=float, help='gradient clipping')
    optim_args.add_argument(
        '-e', '--epochs', default=1, type=int, help='total training epochs.')
    optim_args.add_argument(
        '--warmup-ratio',
        default=0.03,
        type=float,
        help=('the proportion of training steps for learning rate warm-up in '
              'relation to the total training steps.'))

    parser.add_argument(
        '--work-dir',
        default='work_dirs',
        help='the dir to save logs and checkpoints')
    parser.add_argument(
        '--checkpoint-interval',
        default=0.25,
        type=float,
        help=('how many steps to save a checkpoint; it can be a floating '
              'point number less than 1, or an integer greater than or equal '
              "to 1. When it's a floating point, it will be multiplied by the "
              'total number of training steps.'))
    parser.add_argument(
        '--checkpoint-drop-optimizer',
        action='store_true',
        help=('only model parameters are saved when saving a checkpoint. '
              'This can significantly reduce the size of checkpoint files, '
              'but the saved checkpoints cannot be resumed.'))
    parser.add_argument('--gc-interval', default=500, type=int)
    # The value is parsed with `type=int`, so (unlike `--checkpoint-interval`)
    # a fractional value is rejected by argparse; the help text says so.
    parser.add_argument(
        '--hf-interval',
        default=-1,
        type=int,
        help=('how many steps between two saves of a hf model; '
              'must be an integer number of steps.'))
    parser.add_argument(
        '--log-interval', default=1, type=int, help='log interval')
    parser.add_argument(
        '--resume', action='store_true', help='resume from the last checkpoint')
    parser.add_argument(
        '--resume-from',
        type=str,
        default=None,
        help='specify checkpoint path to be resumed from.')
    parser.add_argument(
        '--seed', type=int, default=0, help='random seed for the training')
    args = parser.parse_args()
    return args


def is_interval(step, total_steps, interval):
    """Tell whether a periodic action is due after finishing ``step``.

    Args:
        step: zero-based index of the step that has just been completed.
        total_steps: total number of steps of the run.
        interval: period, counted in completed steps.

    Returns:
        bool: True when the number of completed steps (``step + 1``) is a
        multiple of ``interval`` or equals ``total_steps``.
    """
    completed = step + 1
    if completed % interval == 0:
        return True
    return completed == total_steps


def log_format(rank, debug=False):
    """Compose the log format string used for the logger sinks.

    Args:
        rank: rank number embedded in the ``[XTuner][RANK n]`` prefix.
        debug: when truthy, a ``[name:function:line]`` location segment is
            inserted before the message.

    Returns:
        str: the format template (placeholders such as ``{time}`` and
        ``{message}`` are left for the logger to fill in).
    """
    segments = [
        f'[XTuner][RANK {rank}]',
        '[{time:YYYY-MM-DD HH:mm:ss}][<level>{level}</level>]',
    ]

    if debug:
        segments.extend((
            '[<cyan>{name}</cyan>:',
            '<cyan>{function}</cyan>:',
            '<cyan>{line}</cyan>]',
        ))

    segments.append(' <level>{message}</level>')
    return ''.join(segments)


def check_args(args):
    """Validate the parsed arguments and normalize them in place.

    What it does, in order:

    * ``--resume-from`` implies ``--resume``.
    * With ``--resume`` but no explicit path, the ``ckpt-*`` directories of
      ``args.work_dir`` are scanned from the highest name downwards; the
      first one containing a ``.metadata`` file becomes ``args.resume_from``.
      Directories without ``.metadata`` that are met before it are treated as
      incomplete and deleted.  If nothing usable is found, ``args.resume`` is
      reset to False and training starts from scratch.
    * Batch-size and parallelism constraints are checked.
    * ``gradient_sync_after_accumulate`` is switched off when
      ``reshard_after_forward`` is set.

    Args:
        args: namespace produced by ``parse_args``; modified in place.

    Raises:
        AssertionError: ``resume`` and ``checkpoint_drop_optimizer`` are both set.
        ValueError: ``global_batch_size`` is not a multiple of the data-parallel
            size, or the per-rank batch is not a multiple of ``mirco_batch_size``.
        NotImplementedError: ``sp_size > 1`` together with ``mirco_batch_size > 1``.
    """
    if args.resume_from and args.resume is False:
        args.resume = True

    if args.resume is True and args.resume_from is None:
        # Find the last checkpoint. A work dir that does not exist yet simply
        # contains no checkpoint, it is not an error.
        try:
            entries = os.listdir(args.work_dir)
        except FileNotFoundError:
            entries = []

        ckpt_dirs = sorted(
            (d for d in entries
             if os.path.isdir(os.path.join(args.work_dir, d)) and d.startswith('ckpt-')),
            reverse=True)

        for ckpt_dir in ckpt_dirs:
            ckpt_path = os.path.join(args.work_dir, ckpt_dir)
            if os.path.exists(os.path.join(ckpt_path, '.metadata')):
                args.resume_from = ckpt_path
                break
            # No `.metadata` file: the checkpoint is considered incomplete and
            # removed. `shutil.rmtree` avoids passing the path through a shell
            # (spaces / metacharacters in `--work-dir`); errors are ignored,
            # matching the previous `rm -rf` behaviour.
            shutil.rmtree(ckpt_path, ignore_errors=True)

        if args.resume_from is None:
            logger.warning('Did not find last_checkpoint to be resumed. training from scratch.')
            args.resume = False

    if args.resume:
        assert not args.checkpoint_drop_optimizer, '`resume` and `checkpoint_drop_optimizer` cannot be set at the same time.'

    dp_size = get_dp_mesh().size()
    # The constraints below are on the data-parallel size, so that is the
    # number reported in the error messages.
    if args.global_batch_size < dp_size or args.global_batch_size % dp_size:
        raise ValueError(f'The `global_batch_size`({args.global_batch_size}) '
                         f'should be divisible by the dp_size({dp_size}).')

    if (args.global_batch_size // dp_size) % args.mirco_batch_size:
        raise ValueError(f'The `global_batch_size`({args.global_batch_size}) '
                         f'should be divisible by the dp_size({dp_size})*'
                         f'`mirco_batch_size`({args.mirco_batch_size})')

    if args.group_by_length and args.group_by_modality_length:
        logger.warning('if you set both `group_by_length` and `group_by_modality_length`,'
                       ' the `group_by_modality_length` will be used.')

    if args.sp_size > 1 and args.mirco_batch_size > 1:
        raise NotImplementedError('Not support mirco_batch_size>1 when sp_size>1')

    if args.gradient_sync_after_accumulate is True:
        if args.reshard_after_forward is True:
            args.gradient_sync_after_accumulate = False
            logger.warning(
                '`gradient_sync_after_accumulate` and `reshard_after_forward` cannot be set at the same time. '
                'force to set `gradient_sync_after_accumulate` to False.')


def set_logger_envs(args):
    """Configure the logger sinks and, on rank 0, log the runtime environment.

    Every rank creates ``args.work_dir`` if needed, replaces the existing
    logger sinks with a stderr sink (INFO level) and a ``rank{rank}.log`` file
    sink, then logs ``args``.  Rank 0 additionally logs the collected
    environment, library versions, seed, world size and the git information
    of up to three repositories, and copies this script into the work dir.

    Args:
        args: parsed command-line namespace; ``work_dir`` and ``seed`` are read.
    """
    rank = get_world_mesh().get_rank()
    world_size = get_world_mesh().size()

    mkdir_or_exist(args.work_dir)

    logger.remove()
    # Terminal sink
    logger.add(sys.stderr, level='INFO', format=log_format(rank))
    # Per-rank file sink
    logger.add(os.path.join(args.work_dir, f'rank{rank}.log'),
               format=log_format(rank), backtrace=True, catch=True)

    logger.info(args)
    if rank != 0:
        return

    env = collect_env()
    import transformers

    import xtuner
    env['Transformers'] = transformers.__version__
    env['XTuner'] = f'{xtuner.__version__}+{get_git_hash(digits=6)}'

    runtime_env = OrderedDict(env)
    runtime_env['Seed'] = args.seed
    runtime_env['World Size'] = world_size

    # Git info of the repo holding this script and of two sibling directories;
    # a repo is skipped when no branch is reported for it.
    script_dir = os.path.dirname(os.path.abspath(__file__))
    repos = (
        ('xpuyu', script_dir),
        ('xtuner', os.path.join(script_dir, '../xtuner')),
        ('janus', os.path.join(script_dir, '../janus')),
    )
    for prefix, repo_path in repos:
        branch, commit_id, remote_url = get_repo_git_info(repo_path)
        if branch is None:
            continue
        runtime_env[f'{prefix}_branch'] = branch
        runtime_env[f'{prefix}_commit_id'] = commit_id
        runtime_env[f'{prefix}_remote_url'] = remote_url

    indent = '\n    '
    runtime_env_info = indent + indent.join(
        f'{key}: {value}' for key, value in runtime_env.items())
    dash_line = '-' * 60

    logger.info('\n' + dash_line + '\nRuntime environment:' +
                runtime_env_info + '\n' + dash_line + '\n')

    # Keep a copy of the launching script next to the logs.
    shutil.copy(__file__, args.work_dir)


def record_tensorboard(tensorboard_kargs, queue):
    """Consume ``(tag, value, step)`` records from ``queue`` and write them to TensorBoard.

    Runs until a record whose tag is ``'over'`` is received.  Records whose
    tag contains ``'/img'`` are written with ``add_image``; every other
    record is written with ``add_scalar``.

    Args:
        tensorboard_kargs: keyword arguments forwarded to ``SummaryWriter``.
        queue: queue object providing a blocking ``get()`` that yields
            ``(tag, value, step)`` tuples.
    """
    writer = SummaryWriter(**tensorboard_kargs)
    try:
        while True:
            # Block until a record arrives instead of polling `empty()` and
            # sleeping: no idle wake-ups and no added latency per record.
            tag, value, step = queue.get()
            if tag == 'over':
                break

            if '/img' in tag:
                writer.add_image(tag, value, step)
            else:
                writer.add_scalar(tag, value, step)
    finally:
        # Close the writer even if writing a record raised.
        writer.close()


class SummaryWriterWrapper(SummaryWriter):
    """Front end that forwards TensorBoard records to a writer process.

    ``SummaryWriter.__init__`` is not invoked here; the actual writer is
    created inside a spawned process running ``record_tensorboard``, and
    records reach it through a queue.  When ``only_rank0`` is set, ranks other
    than 0 hold no queue and no process, so their ``add_scalar`` calls only
    perform the optional all-reduce.
    """

    def __init__(
            self,
            # tensorboard args
            log_dir=None,
            comment="",
            purge_step=None,
            max_queue=10,
            flush_secs=120,
            filename_suffix="",
            queue_size=3000,
            only_rank0=True,
    ):
        """Start the writer process unless this rank is excluded.

        Args:
            log_dir, comment, purge_step, max_queue, flush_secs,
            filename_suffix: forwarded to ``SummaryWriter`` in the child process.
            queue_size: maximum number of pending records in the queue.
            only_rank0: when True, only the process with rank 0 starts a writer.
        """
        self.queue = None
        self.thread = None
        if only_rank0 and dist.get_rank() != 0:
            return

        writer_kwargs = dict(
            log_dir=log_dir,
            comment=comment,
            purge_step=purge_step,
            max_queue=max_queue,
            flush_secs=flush_secs,
            filename_suffix=filename_suffix,
        )
        spawn_ctx = mp.get_context('spawn')
        self.queue = spawn_ctx.Queue(maxsize=queue_size)
        # Despite the attribute name, this is a process, not a thread.
        self.thread = spawn_ctx.Process(
            target=record_tensorboard, args=(writer_kwargs, self.queue)
        )
        self.thread.start()

    def qsize(self):
        """Return the number of pending records, or 0 when there is no queue."""
        return 0 if self.queue is None else self.queue.qsize()

    def add_scalar(
            self,
            tag,
            scalar_value,
            global_step=None,
            walltime=None,
            new_style=False,
            double_precision=False,
            reduce_op=None,
    ):
        """Queue a scalar record, optionally all-reducing it first.

        Args:
            tag: record tag.
            scalar_value: value to record.
            global_step: step attached to the record.
            walltime, new_style, double_precision: accepted but not used here.
            reduce_op: if not None, ``scalar_value`` is moved to a CUDA tensor
                and all-reduced with this op on every calling rank before
                being queued.
        """
        if reduce_op is not None:
            reduced = torch.tensor(scalar_value).cuda()
            dist.all_reduce(reduced, op=reduce_op)
            scalar_value = reduced.item()
        if self.thread is None:
            return
        self.queue.put((tag, scalar_value, global_step))

    def add_optimize_info(self, grad_norm, inf_nan_skip_batches, lr, steps):
        """Queue gradient norm, learning rate and skipped-batch count for ``steps``."""
        records = (
            ("optimize/grad_norm", grad_norm),
            ("optimize/lr", lr),
            ("optimize/inf_nan_skip_batches", inf_nan_skip_batches),
        )
        for tag, value in records:
            self.add_scalar(tag, value, global_step=steps)

    def add_speed_info(self, tgs, e2e_tgs, step):
        """Queue the two throughput figures and the current queue backlog."""
        self.add_scalar("speed/tgs", tgs, step, reduce_op=None)
        self.add_scalar("speed/e2e_tgs", e2e_tgs, step, reduce_op=None)
        # The backlog is sampled after the two records above were queued.
        backlog = self.qsize()
        self.add_scalar("speed/tb_qsize", backlog, step, reduce_op=None)

    def close(self):
        """Ask the writer process to stop by queueing the ``'over'`` record."""
        if self.queue is None:
            return
        self.queue.put(('over', 0, 0))


def resume(args, fsdp_model, optimizer, warmup_scheduler, cosine_scheduler, start_step, total_steps,
           inf_nan_skip_batches):
    """Load the checkpoint at ``args.resume_from`` into the model, optimizer and schedulers.

    The current (sharded) model and optimizer state dicts are used as the
    destination of ``dcp.load`` and are then pushed back with
    ``set_state_dict``.  The two schedulers are handed to ``dcp.load`` as
    objects rather than as state dicts.

    Args:
        args: namespace providing ``resume_from`` (checkpoint path).
        fsdp_model: model to restore.
        optimizer: optimizer to restore.
        warmup_scheduler, cosine_scheduler: LR schedulers to restore.
        start_step, total_steps, inf_nan_skip_batches: initial values for the
            ``MetaStateful`` container that receives the stored counters.

    Returns:
        tuple: ``(start_step, inf_nan_skip_batches)`` as read back from the
        loaded ``MetaStateful``.
    """
    logger.info(f'[Resume] Resume from {args.resume_from}')

    gather_options = StateDictOptions(
        cpu_offload=True, ignore_frozen_params=True)
    model_shards, optimizer_shards = get_state_dict(
        fsdp_model, optimizer, options=gather_options)
    meta_stateful = MetaStateful(step=start_step, total_steps=total_steps, inf_nan_skip_batches=inf_nan_skip_batches)

    state_dict = dict(
        model=model_shards,
        optimizer=optimizer_shards,
        meta_stateful=meta_stateful,
        warmup_scheduler=warmup_scheduler,
        cosine_scheduler=cosine_scheduler,
    )
    # `dcp.load` fills `state_dict` in place.
    dcp.load(
        state_dict=state_dict,
        checkpoint_id=args.resume_from,
    )

    # Non-strict: the gathered state dict omitted frozen parameters.
    set_state_dict(
        fsdp_model,
        optimizer,
        model_state_dict=state_dict["model"],
        optim_state_dict=state_dict["optimizer"],
        options=StateDictOptions(cpu_offload=True, strict=False),
    )

    resumed_step = meta_stateful['step']
    resumed_skip_batches = meta_stateful['inf_nan_skip_batches']
    logger.info(f'[Resume] start_step to {resumed_step}')
    return resumed_step, resumed_skip_batches


def save_ckpt(args, step, total_steps, inf_nan_skip_batches, fsdp_model, rank0_model, warmup_scheduler,
              cosine_scheduler, optimizer,
              max_keep_ckpts, save_hf_ckpt_names, save_pt_ckpt_names, tokenizer, processor, future,
              save_pt=True,
              save_hf=True):
    """Save a HuggingFace-format export and/or a resumable distributed checkpoint.

    Output directories are ``<work_dir>/hf-<id>`` and ``<work_dir>/ckpt-<id>``
    where ``<id>`` is ``<step+1>-of-<total_steps>``, zero-padded.

    HF export (``save_hf``): every rank gathers the full tensors of
    ``fsdp_model``; rank 0 loads them into ``rank0_model`` and calls
    ``save_pretrained`` on the model, tokenizer and processor.

    Resumable checkpoint (``save_pt``): after waiting for the previous async
    save (``future``), the sharded model/optimizer state, the step counters
    and both scheduler states are written with ``dcp.async_save``.

    On rank 0 the oldest entry of ``save_hf_ckpt_names`` /
    ``save_pt_ckpt_names`` is deleted from disk once the list grows beyond
    ``max_keep_ckpts``; both lists are mutated in place.

    Args:
        args: namespace providing ``work_dir`` and ``checkpoint_drop_optimizer``.
        step: zero-based index of the step just finished.
        total_steps: total number of training steps.
        inf_nan_skip_batches: counter stored alongside the step.
        fsdp_model: sharded model being trained.
        rank0_model: full model used on rank 0 for the HF export.
        warmup_scheduler, cosine_scheduler: LR schedulers whose state is saved.
        optimizer: optimizer whose state is saved.
        max_keep_ckpts: maximum number of checkpoints of each kind kept on disk.
        save_hf_ckpt_names, save_pt_ckpt_names: lists of saved paths, oldest first.
        tokenizer, processor: saved with the HF export when not None.
        future: handle of the previous async save, or None.
        save_pt: whether to write the resumable checkpoint.
        save_hf: whether to write the HF export.

    Returns:
        The future of the async save started by this call, or the incoming
        ``future`` unchanged when no async save was started.
    """
    digits = len(str(abs(total_steps)))
    work_dir = args.work_dir

    ckpt_id = f'{(step + 1):0{digits}}-of-{total_steps:0{digits}}'
    ckpt_dir = os.path.join(work_dir, f'ckpt-{ckpt_id}')
    hf_dir = os.path.join(work_dir, f'hf-{ckpt_id}')

    rank = dist.get_rank()
    if save_hf:
        with profile_time_and_memory('[HF Checkpoint]'):
            from torch.distributed._tensor import DTensor

            if rank == 0:
                llm_state_dict = {}

            with torch.no_grad():
                for name, param in fsdp_model.state_dict().items():
                    # `full_tensor()` is called on every rank; only rank 0
                    # keeps the gathered result.
                    if isinstance(param, DTensor):
                        full_param = param.full_tensor().cpu()
                    else:
                        full_param = param.cpu()

                    if rank == 0:
                        llm_state_dict[name] = full_param

            if rank == 0:
                rank0_model.load_state_dict(llm_state_dict)
                rank0_model.save_pretrained(hf_dir)
                if tokenizer is not None:
                    tokenizer.save_pretrained(hf_dir)
                if processor is not None:
                    processor.save_pretrained(hf_dir)

                save_hf_ckpt_names.append(hf_dir)
                if len(save_hf_ckpt_names) > max_keep_ckpts:
                    remove_hf_ckpt_name = save_hf_ckpt_names.pop(0)
                    # Delete without going through a shell so that spaces or
                    # metacharacters in the work dir cannot alter the command.
                    shutil.rmtree(remove_hf_ckpt_name, ignore_errors=True)
        dist.barrier()

    if save_pt:
        # Only one async save may be in flight: wait for the previous one.
        if future is not None:
            wait([future])

        if args.checkpoint_drop_optimizer:
            # NOTE(review): nothing is written in this branch, although the
            # CLI help says model parameters are still saved - confirm intent.
            logger.warning('[Checkpoint] The saved checkpoint cannot be '
                           'resumed. If you want to save a resumable '
                           'checkpoint, please remove '
                           '`--checkpoint-drop-optimizer` '
                           'from the command.')
        else:
            xtuner_load_timeout = timedelta(minutes=60)
            # NOTE(review): a new gloo group is created on every save and never
            # destroyed; consider creating it once and reusing it.
            group_gloo = dist.new_group(backend='gloo', timeout=xtuner_load_timeout)
            with profile_time_and_memory('[PT Checkpoint]'):
                if dist.get_rank() == 0:
                    mkdir_or_exist(ckpt_dir)
                dist.barrier()

                # FSDP cannot be saved via torch.save
                # Refer to https://pytorch.org/tutorials/recipes/distributed_checkpoint_recipe.html  # noqa: E501
                _options = StateDictOptions(
                    cpu_offload=True, ignore_frozen_params=True)
                (shard_model_state_dict,
                 shard_optimizer_state_dict) = get_state_dict(
                    fsdp_model, optimizer, options=_options)
                meta_stateful = MetaStateful(step=step + 1, total_steps=total_steps,
                                             inf_nan_skip_batches=inf_nan_skip_batches)
                state_dict = {
                    'model': shard_model_state_dict,
                    'optimizer': shard_optimizer_state_dict,
                    'meta_stateful': meta_stateful,
                    'warmup_scheduler': warmup_scheduler.state_dict(),
                    'cosine_scheduler': cosine_scheduler.state_dict()
                }
                future = dcp.async_save(state_dict, checkpoint_id=ckpt_dir, process_group=group_gloo)

                def send_to_oss_and_remove(future):
                    # Runs once the async save is done: register the new
                    # checkpoint and drop the oldest one beyond the limit.
                    # TODO: send to oss

                    if dist.get_rank() == 0:
                        save_pt_ckpt_names.append(ckpt_dir)
                        if len(save_pt_ckpt_names) > max_keep_ckpts:
                            remove_pt_ckpt_name = save_pt_ckpt_names.pop(0)
                            shutil.rmtree(remove_pt_ckpt_name, ignore_errors=True)

                future.add_done_callback(send_to_oss_and_remove)
    return future


class LazyJanusDataset(BaseOrigDataset):
    """Dataset of image-in / image-out conversations tokenized on access for Janus.

    A multimodal sample is read as: the first conversation entry carries the
    input image, the second one carries the image to generate.  Text is
    rendered with a fixed ``User: ... Assistant:`` template and each
    ``<image>`` placeholder is expanded to
    ``image_start + num_image_tokens * image_tag + image_end``.

    Attributes such as ``raw_data``, ``_is_jsonl``, ``tokenizer``, ``root``,
    ``max_length`` and ``data_name`` are read but not assigned in this class;
    presumably they are provided by ``BaseOrigDataset`` - verify against the
    base class.
    """

    def __init__(self, data_name, data, model_name,
                 max_length=4096,
                 group_by_length=False,
                 pack_data=False, pack_data_cache_dir=None):
        """Load the Janus chat processor and initialise the base dataset.

        Args:
            data_name: dataset name forwarded to the base class.
            data: dataset description (dict-like); ``default_num_tokens`` is
                read from it, the rest is consumed by the base class.
            model_name: path or repo id given to ``VLChatProcessor.from_pretrained``.
            max_length: maximum token length of one sample.
            group_by_length, pack_data, pack_data_cache_dir: forwarded to the
                base class.
        """
        self.processor = VLChatProcessor.from_pretrained(model_name)
        self.chat_template = dict(system='{system}\n\n',
                                  user='User: {user}\n\nAssistant:',
                                  assistant='{assistant}<｜end▁of▁sentence｜>')

        # A positive value is used as the token count of every sample by
        # `calc_packing_info`, which then skips tokenization.
        self.default_num_tokens = data.get('default_num_tokens', -1)

        self.use_und_gen = False
        logger.info(f'============Use und_gen: {self.use_und_gen}===============')
        super().__init__(data_name, data, self.chat_template,
                         tokenizer=self.processor.tokenizer,
                         max_length=max_length,
                         group_by_length=group_by_length,
                         pack_data=pack_data,
                         pack_data_cache_dir=pack_data_cache_dir)

    def calc_group_len(self):
        """Estimate a length per sample for length-grouped sampling.

        Returns:
            list[int]: one entry per sample; positive (text tokens plus
            ``num_image_tokens`` per image) for samples with images, negative
            (minus the text token count) for text-only samples.
        """
        group_length = []
        print('Calculating the length of text data...')
        # Cache keyed by the character length of the text: two texts of the
        # same character length share one token count (an approximation).
        conv2length_text = {}

        for data_item in self.raw_data:
            if self._is_jsonl:
                data_item = json.loads(data_item)

            num_images = 0
            for temp in data_item['conversations']:
                num_images += len(temp.get('images', []))

            conversations = '\n'.join([temp['value'] for temp in data_item['conversations']])
            str_length = len(conversations)
            if str_length not in conv2length_text:
                # NOTE(review): `tokenizer.encoder(...)` is used here while the
                # rest of the class calls `tokenizer.encode(...)`; confirm the
                # tokenizer really exposes a callable `encoder` returning a
                # 2-D tensor.
                token_length = self.tokenizer.encoder(conversations).size(1)
                conv2length_text[str_length] = token_length
            else:
                token_length = conv2length_text[str_length]

            if num_images > 0:
                token_length += num_images * self.processor.num_image_tokens
                group_length.append(token_length)
            else:
                group_length.append(-token_length)
        print('Finished calculating the length of text data...')
        return group_length

    def calc_packing_info(self):
        """Return the per-sample token counts used for packing.

        With a positive ``default_num_tokens`` every sample is assigned that
        count; otherwise the base-class computation is used and its result is
        returned.
        """
        if self.default_num_tokens > 0:
            num_samples = len(self.raw_data)
            num_tokens = np.array([self.default_num_tokens] * num_samples)
            return num_tokens
        else:
            # The result must be propagated; without `return` this branch
            # yielded None while the other branch yields the token counts.
            return super().calc_packing_info()

    def pre_tokenize_fn_for_pack(self, data_item):
        """Compute the token count of one raw sample for packing.

        Args:
            data_item: raw sample (a JSON string when the data is jsonl).

        Returns:
            dict: ``{'num_tokens': <int>}``.

        Raises:
            NotImplementedError: the sample contains no image.
        """
        if self._is_jsonl:
            data_item = json.loads(data_item)

        num_images = 0
        for temp in data_item['conversations']:
            num_images += len(temp.get('images', []))

        if num_images > 0:
            num_tokens = self.multi_modal_get_item(data_item, pack_data=True)
        else:
            raise NotImplementedError
        return {'num_tokens': num_tokens}

    def multi_modal_get_item(self, data_item, pack_data=False):
        """Build the training sample (or only its length) for an image sample.

        Args:
            data_item: parsed sample; ``conversations[0]['images'][0]`` is the
                input image path and ``conversations[1]['images'][0]`` the
                path of the image to generate, both relative to ``self.root``.
            pack_data: when True, only the token count is returned.

        Returns:
            int when ``pack_data`` is True; otherwise a dict with
            ``input_ids``, ``labels``, ``image_flags``, ``gen_image_bounds``,
            ``pixel_values``, ``num_tokens``, ``num_img_tokens`` and, when
            ``use_und_gen`` is set, an extra ``und_gen`` sub-sample.
        """
        if pack_data:
            ret = self.process_text(data_item['conversations'], media_type='image')
            if self.use_und_gen:
                # One extra unconditional image-generation sequence is added:
                # bos_id + image_start_id + image tokens + image_end_id + eos_id
                return len(ret['input_ids']) + (self.processor.num_image_tokens + 4)
            else:
                return len(ret['input_ids'])

        image_path_in = data_item['conversations'][0]['images'][0]
        image_path_out = data_item['conversations'][1]['images'][0]

        images = []
        for image_path in [image_path_in, image_path_out]:
            image_path = os.path.join(self.root, image_path)
            image = Image.open(image_path).convert('RGB')
            image = apply_exif_orientation(image)
            images.append(image)

        image_inputs = self.processor.image_processor(images, return_tensors="pt")
        pixel_values = image_inputs["pixel_values"]
        ret = self.process_text(data_item['conversations'], media_type='image')
        # The generated image is always at the end of the sequence, so its
        # bounds are hard-coded relative to the sequence length (the last two
        # tokens follow the image tokens).
        # NOTE(review): if `process_text` truncated the sequence these bounds
        # no longer point at the image tokens - confirm truncation cannot
        # happen for this data.
        ret = dict(
            input_ids=ret['input_ids'],
            labels=ret['labels'],
            image_flags=torch.tensor([1], dtype=torch.long),
            gen_image_bounds=torch.tensor([len(ret['input_ids']) - 2 - self.processor.num_image_tokens,
                                           len(ret['input_ids']) - 2], dtype=torch.long),
            pixel_values=pixel_values,
            num_tokens=[len(ret['input_ids'])],
            num_img_tokens=[self.processor.num_image_tokens * 2],
        )

        if self.use_und_gen:
            # One extra unconditional image-generation sequence.
            new_gen_input_ids = [self.tokenizer.bos_token_id] + [
                self.processor.image_start_id] + self.processor.num_image_tokens * [self.processor.image_id] + [
                                    self.processor.image_end_id] + [self.tokenizer.eos_token_id]
            new_gen_labels = [self.tokenizer.bos_token_id] + [self.processor.image_start_id] + \
                             [-100] * self.processor.num_image_tokens + [-100] + [self.tokenizer.eos_token_id]
            new_ret = dict(
                input_ids=new_gen_input_ids,
                labels=new_gen_labels,
                num_tokens=[len(new_gen_input_ids)],
                num_img_tokens=[self.processor.num_image_tokens],
            )
            ret['und_gen'] = new_ret
        return ret

    def process_text(self, conversations, media_type='image', image_grids=None):
        """Render and tokenize a conversation into ``input_ids`` and ``labels``.

        Consecutive ``human`` messages are concatenated and paired with the
        following ``gpt`` message to form one turn.  User tokens are labelled
        ``-100``; assistant tokens keep their ids as labels except for two
        positions around the generated image (see the comment below).

        Args:
            conversations: list of ``{'from': 'human'|'gpt', 'value': str, ...}``;
                its length must be even.
            media_type: with ``'image'`` the first user turn must contain
                exactly one ``<image>`` placeholder.
            image_grids: unused.

        Returns:
            dict: ``{'input_ids': list[int], 'labels': list[int]}``, both
            truncated to ``self.max_length``.

        Raises:
            NotImplementedError: a message has an unknown ``from`` value.
            AssertionError: odd conversation length or a missing/duplicated
                ``<image>`` placeholder.
        """
        assert len(conversations) % 2 == 0, f'Invalid conversation length: {len(conversations)}'

        input_ = ''
        out_conversation = []
        for msg in conversations:
            if msg['from'] == 'human':
                input_ += msg['value'].strip()
            elif msg['from'] == 'gpt':
                temp_dict = {'input': input_, 'output': msg['value'].strip(), 'images': msg.get('images', [])}
                out_conversation.append(temp_dict)
                input_ = ''
            else:
                raise NotImplementedError(f'Unsupported message type: {msg}')

        input_ids, labels = [], []

        replace_image_str = f'{self.processor.image_start_tag}'
        replace_image_str += f'{self.processor.image_tag}' * self.processor.num_image_tokens
        replace_image_str += f'{self.processor.image_end_tag}'

        for i, single_turn_conversation in enumerate(out_conversation):
            input_ = single_turn_conversation.get('input', '')
            if input_ is None:
                input_ = ''
            input_ = self.chat_template['user'].format(user=input_)

            if i == 0:
                if media_type == 'image':
                    assert '<image>' in input_, f'Image placeholder not found in the first conversation: {input_}'
                    assert input_.count(
                        '<image>') == 1, f'Multiple image placeholders found in the first conversation: {input_}'

                    input_ = input_.replace('<image>', replace_image_str)

                input_encode = self.tokenizer.encode(input_)
            else:
                input_encode = self.tokenizer.encode(input_)
                # remove bos_id
                input_encode = input_encode[1:]

            input_ids += input_encode
            labels += [-100] * len(input_encode)

            output_text = single_turn_conversation.get('output', '')
            output_encode = self.chat_template['assistant'].format(assistant=output_text)
            # NOTE(review): the 'images' key is always present (it is set
            # above, possibly to an empty list), so every assistant turn is
            # required to contain exactly one image.
            if 'images' in single_turn_conversation:
                assert '<image>' in output_encode, f'Image placeholder not found in the first conversation: {output_encode}'
                assert len(single_turn_conversation['images']) == 1
                output_encode = output_encode.replace('<image>', replace_image_str)
            output_encode = self.tokenizer.encode(output_encode)  # always adds bos_id
            # remove bos_id
            input_ids += output_encode[1:]

            output_encode_ = copy.deepcopy(output_encode[1:])
            # Note: the token right after image_start is actually an image
            # token, so the LLM head cannot compute a loss there; the
            # image_end position is excluded from the loss as well.
            output_encode_[output_encode_.index(self.processor.image_start_id) + 1] = -100
            output_encode_[output_encode_.index(self.processor.image_end_id)] = -100

            labels += output_encode_

        if len(input_ids) > self.max_length:
            # Remember the length before cutting so that the log reports the
            # real, over-long length instead of `max_length`.
            orig_length = len(input_ids)
            input_ids = input_ids[:self.max_length]
            labels = labels[:self.max_length]
            logger.info(
                f'Warning: input_ids length({orig_length}) '
                f'is longer than max_length, cut to {self.max_length}')
        return {'input_ids': input_ids, 'labels': labels}

    def pure_text_get_item(self, data_item, pack_data=False):
        """Build the training sample (or only its length) for a text-only sample.

        A blank white 224x224 image is processed and repeated twice to fill
        ``pixel_values``; ``image_flags`` is 0.

        Args:
            data_item: parsed sample with a ``conversations`` list.
            pack_data: when True, only the token count is returned.

        Returns:
            int when ``pack_data`` is True, otherwise the sample dict.
        """
        ret = self.process_text(data_item['conversations'], media_type='text')

        if pack_data:
            # Room for one extra unconditional image-generation sequence:
            # 1: bos_id, 1: image_start_id, plus the image tokens
            return len(ret['input_ids']) + (1 + 1 + self.processor.num_image_tokens)

        image = Image.new('RGB', (224, 224), (255, 255, 255))
        image_inputs = self.processor.image_processor(image, return_tensors="pt")
        pixel_values = image_inputs["pixel_values"].repeat(2, 1, 1, 1)

        ret = dict(
            input_ids=ret['input_ids'],
            labels=ret['labels'],
            image_flags=torch.tensor([0], dtype=torch.long),
            pixel_values=pixel_values,
            num_tokens=[len(ret['input_ids'])],
            num_img_tokens=[0],
        )
        return ret

    def __getitem__(self, i):
        """Return the processed sample at index ``i`` (taken modulo the dataset size).

        Any exception raised while building a sample (including the
        ``NotImplementedError`` for samples without images) is printed and a
        randomly chosen sample is tried instead, until one succeeds.
        """
        i = i % len(self.raw_data)
        while True:
            try:
                data_item = self.raw_data[i]
                if self._is_jsonl:
                    data_item = json.loads(data_item)

                num_images = 0
                for conversation in data_item['conversations']:
                    num_images += len(conversation.get('images', []))

                if num_images > 0:
                    ret = self.multi_modal_get_item(data_item)
                else:
                    raise NotImplementedError
                break
            except Exception as e:
                print(f'Exception: {e} of {self.data_name}', flush=True)
                i = random.randint(0, len(self.raw_data) - 1)
        return ret


def packing_collate(features, pack_batch=True, pad_id=0, sp_size=1):
    """Concatenate a batch of samples into a single packed sequence.

    Args:
        features: Iterable of sample dicts, or lists of sample dicts (one
            nesting level is flattened). Each sample provides ``input_ids``,
            ``labels``, ``num_tokens``, ``num_img_tokens``, ``pixel_values``
            and ``image_flags``; ``gen_image_bounds`` and ``und_gen`` are
            optional.
        pack_batch: Accepted for interface compatibility; not used here.
        pad_id: Accepted for interface compatibility; not used here.
        sp_size: Accepted for interface compatibility; not used here.

    Returns:
        dict with ``input_ids`` and ``labels`` of shape ``(1, total_len)``,
        ``pixel_values`` and ``image_flags`` concatenated along dim 0,
        ``num_tokens`` / ``num_img_tokens`` as ``IntTensor`` and
        ``gen_image_bounds``, which is a concatenated tensor when at least
        one sample supplied bounds and an empty list otherwise.
    """
    # A dataset item may itself be a list of samples; flatten one level.
    samples = []
    for entry in features:
        if isinstance(entry, list):
            samples += entry
        else:
            samples.append(entry)

    id_chunks, label_chunks = [], []
    pixel_chunks, flag_chunks, bound_chunks = [], [], []
    token_counts, img_token_counts = [], []

    def _append_tokens(part):
        """Queue the token-level fields of ``part`` and return its length."""
        id_chunks.append(torch.LongTensor(part['input_ids']))
        label_chunks.append(torch.LongTensor(part['labels']))
        token_counts.extend(part['num_tokens'])
        img_token_counts.extend(part['num_img_tokens'])
        return len(part['input_ids'])

    offset = 0  # running position of the current sample in the packed sequence
    for sample in samples:
        length = _append_tokens(sample)
        pixel_chunks.append(sample['pixel_values'])
        flag_chunks.append(sample['image_flags'])
        if 'gen_image_bounds' in sample:
            # Bounds are sample-relative; shift them to packed coordinates.
            bound_chunks.append(offset + sample['gen_image_bounds'])
        offset += length

        if 'und_gen' in sample:
            # The companion sequence contributes tokens only; it adds no
            # pixel values or image flags.
            offset += _append_tokens(sample['und_gen'])

    num_tokens = torch.IntTensor(token_counts)
    num_img_tokens = torch.IntTensor(img_token_counts)

    input_ids = torch.cat(id_chunks, dim=0).unsqueeze(0)
    labels = torch.cat(label_chunks, dim=0).unsqueeze(0)
    pixel_values = torch.cat(pixel_chunks, dim=0)
    image_flags = torch.cat(flag_chunks, dim=0)

    # Without any bounds the empty list itself is handed on.
    gen_image_bounds = torch.cat(bound_chunks, dim=0) if bound_chunks else bound_chunks

    return {
        'input_ids': input_ids,
        'labels': labels,
        'pixel_values': pixel_values,
        'image_flags': image_flags,
        'gen_image_bounds': gen_image_bounds,
        'num_tokens': num_tokens,
        'num_img_tokens': num_img_tokens,
    }


def build_llava_model(args, dtype=torch.float32, device='cpu'):
    """Load the Janus model, apply the freeze policy and upcast trainables.

    Args:
        args: Namespace providing ``model`` (passed to ``from_pretrained``)
            and ``freeze_style``.
        dtype: Dtype the whole model is cast to right after loading.
        device: Device context the model is created and configured under.

    Returns:
        The model, with every parameter that still requires gradients
        replaced by a float32 copy.
    """
    with torch.device(device):
        with LoadWoInit():
            model = MultiModalityCausalLM.from_pretrained(args.model)

        model.to(dtype)
        model.train()

        freeze_style = args.freeze_style
        if freeze_style == 'mode1':
            # Freeze everything, then re-enable only the heads and aligners.
            model.requires_grad_(False)
            model.eval()
            for trainable in (model.language_model.lm_head,
                              model.gen_head,
                              model.aligner,
                              model.gen_aligner):
                trainable.requires_grad_(True)
        elif freeze_style == 'mode2':
            # Unfreeze as much as possible: only the two vision towers
            # stay frozen and in eval mode.
            for frozen in (model.gen_vision_model, model.vision_model):
                frozen.requires_grad_(False)
                frozen.eval()

    # Trainable parameters are swapped for float32 copies; frozen ones keep
    # the dtype set above.
    for module in model.modules():
        for p_name, param in module.named_parameters(recurse=False):
            if not param.requires_grad:
                continue
            fp32_param = torch.nn.Parameter(param.to(dtype=torch.float32))
            setattr(module, p_name, fp32_param)
    return model


def vlm_train(args):
    """Run the complete training job for the Janus multimodal model.

    The function sets up the parallel meshes, builds the datasets and the
    dataloader, builds the model (on the meta device everywhere and on CPU
    on rank 0), shards it with ``megatron_parallelize``, creates the AdamW
    optimizer with warmup and cosine schedulers, optionally resumes, and
    then runs the optimizer-step loop with gradient accumulation, logging
    and checkpointing. The process group is destroyed at the end.

    Args:
        args: Parsed command-line namespace. Fields read here include
            ``tp_size``, ``sp_size``, ``ring_size``, ``seed``, ``datasets``,
            ``model``, ``max_length``, ``group_by_length``, ``dset_pack``,
            ``dset_cache_dir``, ``freeze_style``, ``reshard_after_forward``,
            ``lr``, ``lr_min``, ``wd``, ``warmup_ratio``,
            ``global_batch_size``, ``mirco_batch_size``, ``epochs``,
            ``checkpoint_interval``, ``hf_interval``, ``log_interval``,
            ``gc_interval``, ``max_keep_ckpts``, ``max_grad_norm``,
            ``gradient_sync_after_accumulate``, ``resume``, ``tensorboard``,
            ``work_dir`` and ``liger``. ``args.dtype`` is overwritten with
            ``'bf16'``.

    Raises:
        NotImplementedError: If ``args.liger`` is set.
    """
    if args.liger:
        raise NotImplementedError('Liger is not supported in this version.')

    setup_parallel(tp_size=args.tp_size, sp_size=args.sp_size)
    set_random_seed(args.seed)

    dp_mesh = get_dp_mesh()
    tp_mesh = get_tp_mesh()
    sp_mesh = get_sp_mesh()
    fsdp_mesh = get_fsdp_mesh()  # dp_size * sp_size
    world_mesh = get_world_mesh()  # dp_size * sp_size * tp_size

    dp_size = dp_mesh.size()
    sp_size = sp_mesh.size()
    tp_size = tp_mesh.size()

    rank = world_mesh.get_rank()

    set_logger_envs(args)
    check_args(args)
    with profile_time_and_memory('[Dataset & Dataloader]'):
        # ``args.datasets`` is the path of a JSON file mapping a dataset name
        # to its config; the file is closed as soon as it has been parsed.
        with open(args.datasets) as ds_file:
            ds_collections = json.load(ds_file)
        _datasets = []
        for name, _data in ds_collections.items():
            _dataset = LazyJanusDataset(name, _data, args.model,
                                        max_length=args.max_length,
                                        group_by_length=args.group_by_length,
                                        pack_data=args.dset_pack,
                                        pack_data_cache_dir=args.dset_cache_dir)
            if dist.get_rank() == 0:
                logger.info(f'[Dataset] (Original) {name}: {len(_dataset)} samples.')
            _datasets.append(_dataset)
        train_dataset = build_dataset(args, _datasets)
        logger.warning(f'{dist.get_rank()} ===== End of all dataset =====')
        train_dataloader = build_train_dataloader(args, train_dataset, packing_collate)

    # Training always runs in bf16, whatever was passed on the command line.
    args.dtype = 'bf16'
    dtype = torch.bfloat16
    with profile_time_and_memory('[Model]'):
        meta_model = build_llava_model(args, dtype=dtype, device='meta')
        dispatch_hf_code(meta_model)
        if dist.get_rank() == 0:
            logger.info(meta_model)

        timeout = timedelta(
            minutes=int(os.getenv('XTUNER_DATASET_TIMEOUT', default=45)))
        group = dist.new_group(backend='gloo', timeout=timeout)
        # Only rank 0 builds a real CPU copy of the model; every rank passes
        # its ``rank0_model`` (or None) to ``megatron_parallelize`` below.
        if rank == 0:
            logger.info(f'=====[Build CPU Model]=======')
            rank0_model = build_llava_model(args, dtype=dtype, device='cpu')
        else:
            rank0_model = None
        # Other ranks wait here while rank 0 loads the weights.
        dist.monitored_barrier(group=group, timeout=timeout)

        mp_policy = MixedPrecisionPolicy(param_dtype=dtype, reduce_dtype=dtype)
        fsdp_model = megatron_parallelize(meta_model,
                                          rank0_model,
                                          fsdp_mesh,
                                          tp_mesh=tp_mesh,
                                          mp_policy=mp_policy,
                                          reshard_after_forward=True if args.reshard_after_forward else False,
                                          freeze_style=args.freeze_style)
        if dist.get_rank() == 0:
            logger.info(fsdp_model)

    # Only parameters left trainable by the freeze policy are optimized.
    requried_grad_params = [
        param for param in fsdp_model.parameters() if param.requires_grad
    ]
    requried_grad_name = [name for name, param in fsdp_model.named_parameters() if param.requires_grad]
    if rank == 0:
        logger.info(f'[Optimizer] {requried_grad_name}')

    optimizer = AdamW(
        requried_grad_params, lr=args.lr, weight_decay=args.wd, fused=False)

    max_memory = get_torch_device_module().max_memory_allocated()
    logger.info('[Train] Begin Train Loop. The current GPU memory is '
                f'{(max_memory / 1024 ** 3):.1f}GB')

    global_batch_size = args.global_batch_size
    mirco_batch_size = args.mirco_batch_size

    # `iter` means once forward+backward
    # `step` means once optimizer step
    # `per_step_iters` means gradient accumulative counts
    per_step_iters = global_batch_size // mirco_batch_size // dp_size
    per_epoch_iters = len(train_dataloader)
    per_epoch_steps = math.ceil(per_epoch_iters / per_step_iters)
    logger.info(f'[Optimizer] Global batch size: {global_batch_size}, Gradient accumulative counts: {per_step_iters}')
    total_epochs = args.epochs
    total_steps = per_epoch_steps * total_epochs

    # Interval options: -1 means "once, at the end"; a value below 1 is a
    # fraction of the total steps; anything else is a step count.
    if args.checkpoint_interval == -1:
        checkpoint_interval = total_steps
    elif args.checkpoint_interval < 1:
        checkpoint_interval = int(total_steps * args.checkpoint_interval)
    else:
        checkpoint_interval = int(args.checkpoint_interval)

    if args.hf_interval == -1:
        hf_interval = total_steps
    elif args.hf_interval < 1:
        hf_interval = int(total_steps * args.hf_interval)
    else:
        hf_interval = int(args.hf_interval)

    warmup_steps = int(args.warmup_ratio * total_steps)

    def warmup_fn(x):
        # Linear ramp from 0 to 1 over the warmup steps, then constant 1.
        return x / warmup_steps if x < warmup_steps else 1

    warmup_scheduler = LambdaLR(optimizer, warmup_fn)

    cosine_scheduler = CosineAnnealingLR(
        optimizer, T_max=total_steps - warmup_steps, eta_min=args.lr_min)

    start_step = 0
    inf_nan_skip_batches = 0
    if args.resume:
        start_step, inf_nan_skip_batches = resume(args, fsdp_model, optimizer, warmup_scheduler,
                                                  cosine_scheduler, start_step, total_steps,
                                                  inf_nan_skip_batches)

    start_train_t = time.time()
    torch.cuda.empty_cache()
    torch.cuda.reset_peak_memory_stats()
    max_memory = torch.cuda.max_memory_allocated()

    # Pick up checkpoint directories already present in the work dir; both
    # lists are handed to ``save_ckpt`` together with ``max_keep_ckpts``.
    save_hf_ckpt_names = []
    save_pt_ckpt_names = []
    ckpt_dirs = [os.path.join(args.work_dir, d) for d in os.listdir(args.work_dir) if
                 os.path.isdir(os.path.join(args.work_dir, d)) and d.startswith('ckpt-')]
    if len(ckpt_dirs) > 0:
        ckpt_dirs.sort()
        save_pt_ckpt_names = ckpt_dirs

    hf_dirs = [os.path.join(args.work_dir, d) for d in os.listdir(args.work_dir) if
               os.path.isdir(os.path.join(args.work_dir, d)) and d.startswith('hf-')]
    if len(hf_dirs) > 0:
        hf_dirs.sort()
        save_hf_ckpt_names = hf_dirs

    # A non-positive limit is replaced by a very large keep count.
    max_keep_ckpts = args.max_keep_ckpts
    if max_keep_ckpts <= 0:
        max_keep_ckpts = 100000000

    if rank == 0:
        if args.dset_pack:
            logger.info(f'======= Using soft packing style. =========')
        if args.sp_size > 1:
            assert args.dset_pack, 'Only support soft packing with sp_size > 1.'
            logger.info(
                f'======= Using SP mode. sp_ulysess:{args.sp_size // args.ring_size}, sp_ring:{args.ring_size}======')

        logger.info('[Train] Begin Train Loop. The current GPU memory is '
                    f'{(max_memory / 1024 ** 3):.1f}GB')
        logger.info('The FSDP adopts a lazy design, so the first iteration will be slow.')
        if args.liger:
            logger.info('====== use liger kernel =====')

    processor = VLChatProcessor.from_pretrained(args.model)
    # Handle returned by the latest ``save_ckpt`` call; waited on after the loop.
    future = None

    if args.tensorboard:
        tbwriter = SummaryWriterWrapper(log_dir=args.work_dir + f'/rank_{dist.get_rank()}',
                                        only_rank0=True)
    else:
        tbwriter = None

    torch.cuda.empty_cache()

    max_memory = torch.cuda.max_memory_allocated()
    logger.info('[Train] Begin Train Loop. The current GPU memory is '
                f'{(max_memory / 1024 ** 3):.1f}GB')

    total_consumed_tokens = 0
    time_used_by_val_and_save_ckpt = 0

    for step in range(start_step, total_steps):
        torch.cuda.reset_peak_memory_stats()

        if is_interval(step + 1, total_steps, args.gc_interval):
            # Automatic garbage collection is disabled at import time
            # (``gc.disable()``), so it is triggered manually here.
            # torch.cuda.empty_cache()
            gc.collect()

        epoch = step // per_epoch_steps
        epoch_inner_step = step % per_epoch_steps
        if epoch_inner_step == 0 or step == start_step:
            # For the first step of each epoch, the data order needs to be
            # readjusted.
            # Or after resuming, for the first step, the dataloader needs to
            # be adjusted to the position before resume.
            # train_dataloader.sampler.set_epoch(epoch, inner_step)
            train_dataloader.sampler.set_epoch(epoch, epoch_inner_step * per_step_iters)
            data_iterator = iter(train_dataloader)

        if step <= warmup_steps:
            warmup_scheduler.step()
            cur_lr = warmup_scheduler.get_lr()[0]
        else:
            cosine_scheduler.step()
            cur_lr = cosine_scheduler.get_lr()[0]

        torch.cuda.reset_peak_memory_stats()

        step_loss = 0
        step_und_loss = 0
        step_gen_loss = 0

        step_consumed_tokens = 0
        step_consumed_img_tokens = 0
        _data_start_t = time.time()

        # Fetch every micro-batch of this step up front so the number of
        # supervised tokens is known before the first backward pass.
        step_data_list = [next(data_iterator) for _ in range(per_step_iters)]
        rank_grad_tokens = 0
        for _iter in range(per_step_iters):
            _iter_data = step_data_list[_iter]
            # Drop the first label column, then count positions with a
            # non-negative label.
            _iter_labels = _iter_data['labels'][:, 1:]
            rank_grad_tokens += (_iter_labels >= 0).sum()
        rank_grad_tokens = rank_grad_tokens.to(DEVICE)
        dist.all_reduce(rank_grad_tokens)
        # NOTE(review): dividing by sp_size and tp_size presumably removes
        # the duplicate counts from ranks that see the same data - confirm.
        global_grad_tokens = rank_grad_tokens / sp_size / tp_size
        step_data_time = time.time() - _data_start_t

        for _iter in range(per_step_iters):
            data = step_data_list[_iter]
            data = move_data_to_device(data)

            input_ids = data['input_ids']
            labels = data['labels']
            pixel_values = data['pixel_values']
            image_flags = data['image_flags']
            gen_image_bounds = data['gen_image_bounds']

            num_tokens = data.pop('num_tokens')
            num_img_tokens = data.pop('num_img_tokens')

            packed_ctx = packed_sequence(num_tokens, enable=True)

            with packed_ctx:

                gen_loss, und_loss = fsdp_model(input_ids, labels, pixel_values, image_flags, gen_image_bounds,
                                                image_id=processor.image_id)

                # Normalize the two losses by the global supervised-token
                # count, scaled by dp_size.
                loss = (gen_loss + und_loss) / global_grad_tokens * dp_size

                if args.gradient_sync_after_accumulate and per_step_iters > 1:
                    # Gradient sync is requested only on the last
                    # micro-batch of the accumulation window.
                    is_accumulating = _iter < per_step_iters - 1
                    fsdp_model.set_is_last_backward(not is_accumulating)
                    fsdp_model.set_requires_gradient_sync(not is_accumulating)

                loss.backward()

            step_loss += loss.item()
            step_und_loss += und_loss.item() / global_grad_tokens * dp_size
            step_gen_loss += gen_loss.item() / global_grad_tokens * dp_size

            step_consumed_tokens += num_tokens.sum() / sp_size / tp_size
            step_consumed_img_tokens += num_img_tokens.sum() / sp_size / tp_size

        grad_norm = clip_grad_norm_(requried_grad_params, fsdp_mesh, args.max_grad_norm)
        if grad_norm.isnan() or grad_norm.isinf():
            # A non-finite gradient norm: drop the update and count the skip.
            inf_nan_skip_batches += 1
            logger.info(f"The grad norm is NaN={grad_norm.isnan()} or Inf={grad_norm.isinf()}, skip this batch.")
            optimizer.zero_grad()
        else:
            optimizer.step()
            optimizer.zero_grad()

        step_text_tokens = step_consumed_tokens - step_consumed_img_tokens
        step_img_tokens = step_consumed_img_tokens
        step_time = time.time() - _data_start_t
        eta = step_time * (total_steps - step)
        eta = timedelta(seconds=int(eta))
        tgs = int(step_consumed_tokens / step_time)
        max_memory = torch.cuda.max_memory_allocated()

        total_consumed_tokens += step_consumed_tokens
        # End-to-end throughput excludes the time spent saving checkpoints.
        end2end_tgs = int(total_consumed_tokens / (time.time() - start_train_t - time_used_by_val_and_save_ckpt))

        if tbwriter is not None:
            tensorboard_start_time = time.time()
            tbwriter.add_optimize_info(grad_norm.detach().clone(), inf_nan_skip_batches, cur_lr, step + 1)
            tbwriter.add_speed_info(tgs, end2end_tgs, step + 1)
            tbwriter.add_scalar('loss/total_loss', step_loss, step + 1)

            # Disabled: dump decoded text and an image grid to tensorboard.
            # if is_interval(step, total_steps, args.log_interval):
            #     show_data=step_data_list[0]
            #     num_tokens = show_data.pop('num_tokens').tolist()
            #     sequences = torch.split(show_data['input_ids'], num_tokens, dim=1)

            #     pixel_values = show_data['pixel_values']
            #     if pixel_values.shape[0] > 16:
            #         pixel_values = pixel_values[:16,...]
            #         sequences = sequences[:16]
            # output_str=''
            # for sequ in sequences:
            #     decode_str = processor.tokenizer.decode(sequ[0])
            #     modified_str = re.sub(rf'({processor.image_tag})+', processor.image_tag, decode_str)
            #     output_str += modified_str + '\n========\n'

            # tbwriter.add_scalar('train_img/text', output_str, step + 1)
            # pixel_values = make_grid(pixel_values, nrow=4)
            # tbwriter.add_scalar('train_img/img', pixel_values, step + 1)

            tensorboard_time = time.time() - tensorboard_start_time
        else:
            tensorboard_time = -1

        if is_interval(step, total_steps, args.log_interval):
            logger.info(
                f'[Train] (Epoch {epoch}) Step {step + 1}/{total_steps}  '  # noqa: E501
                f'lr: {cur_lr:.6f}  loss: {step_loss:.3f}  '
                f'grad_norm: {grad_norm:.2f}  '
                f'und_loss: {step_und_loss:.3f}, gen_loss: {step_gen_loss:.3f}  '
                f'max_memory: {(max_memory / 1024 ** 3):.1f}GB  '
                f'text_tokens: {step_text_tokens}  '
                f'image_tokens: {step_img_tokens}  '
                f'tgs: {tgs} e2e_tgs: {end2end_tgs} data_time: {step_data_time:.2f}s  '
                f'time: {step_time:.2f}s  '
                f'tb_time: {tensorboard_time:.2f}s  '
                f'eta: {eta}')

        # Checkpoint intervals are evaluated on every step, independently of
        # the logging interval, so a checkpoint step that is not also a
        # logging step is never skipped.
        time_before_save = time.time()
        if is_interval(step, total_steps, hf_interval):
            future = save_ckpt(args, step, total_steps, inf_nan_skip_batches, fsdp_model, rank0_model,
                               warmup_scheduler,
                               cosine_scheduler,
                               optimizer, max_keep_ckpts, save_hf_ckpt_names, save_pt_ckpt_names,
                               processor.tokenizer,
                               processor,
                               future, save_pt=False)

        if is_interval(step, total_steps, checkpoint_interval):
            future = save_ckpt(args, step, total_steps, inf_nan_skip_batches, fsdp_model, rank0_model,
                               warmup_scheduler,
                               cosine_scheduler,
                               optimizer, max_keep_ckpts, save_hf_ckpt_names, save_pt_ckpt_names,
                               processor.tokenizer,
                               processor,
                               future, save_hf=False)
        time_used_by_val_and_save_ckpt += time.time() - time_before_save

    if tbwriter is not None:
        tbwriter.close()

    # Wait for the last pending checkpoint handle, if any, before teardown.
    if future is not None:
        wait([future])

    train_cost_time = time.time() - start_train_t
    m, s = divmod(train_cost_time, 60)
    h, m = divmod(m, 60)
    d, h = divmod(h, 24)
    logger.info("[Train] Cost: %d day, %d:%d:%d" % (d, h, m, s))
    # ------------------------    Training  End  ---------------------------- #
    dist.destroy_process_group()


if __name__ == '__main__':
    # Script entry point: parse the command-line arguments and start training.
    args = parse_args()
    vlm_train(args)
